﻿#pragma kernel CalculateRestShapeMatching
#pragma kernel PlasticDeformation
#pragma kernel Project
#pragma kernel Apply

#include "MathUtils.cginc"
#include "AtomicDeltas.cginc"

// --- Constraint topology -------------------------------------------------
// Flattened list of particle indices for all constraints. Constraint i owns the
// entries particleIndices[firstIndex[i]] .. particleIndices[firstIndex[i] + numIndices[i] - 1].
StructuredBuffer<int> particleIndices;
StructuredBuffer<int> firstIndex;
StructuredBuffer<int> numIndices;
// When > 0, Project writes the orientation of every particle in the constraint;
// otherwise only the first particle listed for the constraint is updated.
StructuredBuffer<int> explicitGroup;
// Five floats per constraint (stride 5):
//   [0] factor applied to the position delta in Project,
//   [1] plastic yield, [2] plastic creep, [3] plastic recovery, [4] maximum deformation.
StructuredBuffer<float> shapeMaterialParameters;

// --- Per-constraint state (read/write views) -------------------------------
// Rest center of mass in xyz; w accumulates mass * restPositions.w and is used
// as the shape's total mass (presumably restPositions.w == 1 - TODO confirm).
RWStructuredBuffer<float4> RW_restComs;
// Current center of mass, rewritten by Project every call.
RWStructuredBuffer<float4> coms;
// Rotation extracted per constraint; the previous value is fed back into
// ExtractRotation as the starting guess.
RWStructuredBuffer<quaternion> constraintOrientations;

// Inverse of the rest-shape matrix, written by RecalculateRestData.
RWStructuredBuffer<float4x4> RW_Aqq;
// Best-matching linear transform, written by Project.
RWStructuredBuffer<float4x4> RW_linearTransforms;
// Plastic deformation matrix, updated by PlasticDeformation.
RWStructuredBuffer<float4x4> RW_deformation;

// Particle positions as written by Apply, and particle orientations as read
// and written by Project.
RWStructuredBuffer<float4> RW_positions;
RWStructuredBuffer<quaternion> orientations;

// --- Per-constraint state (read-only views) --------------------------------
// NOTE(review): presumably bound to the same data as the RW_ buffers of the
// same name for kernels that only read it - confirm against the CPU-side binding.
StructuredBuffer<float4> restComs;
StructuredBuffer<float4x4> Aqq;
StructuredBuffer<float4x4> linearTransforms;
StructuredBuffer<float4x4> deformation;

// --- Per-particle input data -------------------------------------------------
StructuredBuffer<float4> positions;
StructuredBuffer<float4> restPositions;
StructuredBuffer<quaternion> restOrientations;
// Inverse masses; the kernels clamp the resulting mass to at most 10000.
StructuredBuffer<float> invMasses;
// Passed together with principalRadii to GetParticleInertiaTensor.
StructuredBuffer<float> invRotationalMasses;
StructuredBuffer<float4> principalRadii;

// Variables set from the CPU
// Number of constraints to process; threads with an index >= this value exit immediately.
uint activeConstraintCount;
// Time step; scales the plastic recovery rate in PlasticDeformation.
float deltaTime;
// Forwarded unchanged to ApplyPositionDelta in the Apply kernel.
float sorFactor;

// Rebuilds the rest-pose data of shape matching constraint i:
//   - RW_restComs[i]: mass-weighted rest center (xyz) and accumulated mass (w),
//   - RW_Aqq[i]:      inverse of the rest shape matrix, including the current
//                     plastic deformation stored in RW_deformation[i].
// If the accumulated mass is below EPSILON both outputs are left zeroed.
void RecalculateRestData(uint i)
{
    // Particles heavier than this (including invMass == 0) are treated as having this mass.
    const float massCap = 10000;

    // Zero the outputs first, so the early-out below leaves well-defined values.
    RW_restComs[i] = FLOAT4_ZERO;
    RW_Aqq[i] = FLOAT4X4_ZERO;

    float4 weightedRestSum = FLOAT4_ZERO;      // sum of mass * restPosition (w included)
    float4x4 positionMoment = FLOAT4X4_ZERO;   // sum of mass * multrnsp4(q, q)
    float4x4 orientationMoment = FLOAT4X4_ZERO; // sum of R * inertia * R^T

    const int begin = firstIndex[i];
    const int count = numIndices[i];

    for (int n = 0; n < count; ++n)
    {
        const int p = particleIndices[begin + n];

        // clamp the particle mass:
        const float invMass = invMasses[p];
        const float mass = (invMass > 1.0f / massCap) ? 1.0f / invMass : massCap;

        const float4 restPos = restPositions[p];
        weightedRestSum += restPos * mass;

        // rotational contribution of the particle, using its rest orientation:
        float4x4 restRotation = q_toMatrix(restOrientations[p]);
        restRotation[3][3] = 0;

        const float4x4 inertia = AsDiagonal(GetParticleInertiaTensor(principalRadii[p], invRotationalMasses[p]));
        orientationMoment += mul(restRotation, mul(inertia, transpose(restRotation)));

        // positional contribution; the w component is dropped first:
        float4 q = restPos;
        q[3] = 0;
        positionMoment += mass * multrnsp4(q, q);
    }

    // w holds the accumulated mass; nothing else to do for a massless shape.
    const float totalMass = weightedRestSum[3];
    if (totalMass < EPSILON)
        return;

    // store the rest center of mass, keeping the total mass in w:
    float4 restCenter = weightedRestSum;
    restCenter.xyz /= totalMass;
    RW_restComs[i] = restCenter;

    // subtract the moment of the center of mass itself:
    restCenter[3] = 0;
    positionMoment -= totalMass * multrnsp4(restCenter, restCenter);
    positionMoment[3][3] = 1; // so that the determinant is never 0 due to all-zeros row/column.

    // apply the plastic deformation to the positional moment and invert:
    const float4x4 plastic = RW_deformation[i];
    RW_Aqq[i] = Inverse(orientationMoment + mul(plastic, mul(positionMoment, transpose(plastic))));
}

// Kernel: one thread per constraint. Recomputes the rest-pose data
// (rest center of mass and inverse rest shape matrix) of every active constraint.
[numthreads(128, 1, 1)]
void CalculateRestShapeMatching (uint3 id : SV_DispatchThreadID)
{
    const uint constraint = id.x;

    // threads past the active range have nothing to do:
    if (constraint < activeConstraintCount)
    {
        RecalculateRestData(constraint);
    }
}

// Kernel: one thread per constraint. For constraint i it
//   1. accumulates the current center of mass and the moment matrices,
//   2. writes the best-matching linear transform to RW_linearTransforms[i],
//   3. extracts the constraint rotation (warm-started from its previous value),
//   4. writes particle orientations (all particles for explicit groups, otherwise
//      only the first particle of the constraint),
//   5. accumulates a position delta towards the goal position for every particle.
[numthreads(128, 1, 1)]
void Project (uint3 id : SV_DispatchThreadID)
{
    const uint i = id.x;
    if (i >= activeConstraintCount) return;

    // Particles heavier than this (including invMass == 0) are treated as having this mass.
    const float massCap = 10000;

    const int begin = firstIndex[i];
    const int count = numIndices[i];

    // The center of mass is accumulated locally and stored to coms[i] once.
    float4 weightedSum = FLOAT4_ZERO;
    float4x4 Apq = FLOAT4X4_ZERO;
    float4x4 Rpq = FLOAT4X4_ZERO;

    int n;
    int p;

    // calculate shape mass, center of mass, and moment matrix:
    for (n = 0; n < count; ++n)
    {
        p = particleIndices[begin + n];

        const float invMass = invMasses[p];
        const float mass = (invMass > 1.0f / massCap) ? 1.0f / invMass : massCap;

        const float4 currentPos = positions[p];
        weightedSum += currentPos * mass;

        // rotational moment: current orientation against rest orientation.
        float4x4 currentRot = q_toMatrix(orientations[p]);
        float4x4 restRot = q_toMatrix(restOrientations[p]);
        currentRot[3][3] = 0;
        restRot[3][3] = 0;

        const float4x4 inertia = AsDiagonal(GetParticleInertiaTensor(principalRadii[p], invRotationalMasses[p]));
        Rpq += mul(currentRot, mul(inertia, transpose(restRot)));

        // positional moment: current position against rest position (w dropped).
        float4 restPos = restPositions[p];
        restPos[3] = 0;
        Apq += mass * multrnsp4(currentPos, restPos);
    }

    // w of the rest center of mass holds the shape's accumulated mass.
    const float4 restComData = restComs[i];
    if (restComData[3] < EPSILON)
    {
        // massless shape: leave the un-normalized sum in coms[i] and bail out.
        coms[i] = weightedSum;
        return;
    }

    const float4 com = weightedSum / restComData[3];
    coms[i] = com;

    // subtract global shape moment:
    float4 restCenter = restComData;
    restCenter[3] = 0;
    Apq -= restComData[3] * multrnsp4(com, restCenter);

    // calculate optimal transform including plastic deformation:
    const float4x4 plastic = deformation[i];
    float4x4 Apq_def = Rpq + mul(Apq, transpose(plastic));
    Apq_def[3][3] = 1;

    // reconstruct full best-matching linear transform:
    RW_linearTransforms[i] = mul(Apq_def, Aqq[i]);

    // extract rotation from transform matrix, using warmstarting and few iterations:
    const quaternion rotation = ExtractRotation(Apq_def, constraintOrientations[i], 5);
    constraintOrientations[i] = rotation;

    // calculate particle orientations: explicit groups orient every particle,
    // other groups only their first (center) particle.
    const int orientedCount = (explicitGroup[i] > 0) ? count : 1;
    for (n = 0; n < orientedCount; ++n)
    {
        p = particleIndices[begin + n];
        orientations[p] = qmul(rotation, restOrientations[p]);
    }

    // finally, obtain rotation matrix and combine it with the plastic deformation:
    float4x4 R = q_toMatrix(rotation);
    R[3][3] = 0;
    const float4x4 goalTransform = mul(R, plastic);

    // first material parameter of the constraint scales the correction:
    const float deltaScale = shapeMaterialParameters[i * 5];

    // calculate and accumulate particle goal positions:
    for (n = 0; n < count; ++n)
    {
        p = particleIndices[begin + n];

        const float4 goal = com + mul(goalTransform, restPositions[p] - restComData);
        const float4 delta = (goal - positions[p]) * deltaScale;

        AddPositionDelta(p, delta);
    }
}

// Kernel: one thread per constraint. Updates the plastic deformation matrix of
// constraint i from the last linear transform computed for it:
//   - if creep is enabled and the deviation from a pure rotation exceeds the yield
//     threshold, part of that deviation is absorbed permanently (clamped to the
//     maximum deformation),
//   - if recovery is enabled, the deformation is blended back towards identity.
// The rest data is recalculated whenever the deformation may have changed.
[numthreads(128, 1, 1)]
void PlasticDeformation (uint3 id : SV_DispatchThreadID)
{
    const uint i = id.x;
    if (i >= activeConstraintCount) return;

    // get plastic deformation parameters (slots 1..4 of the constraint's material block):
    const uint paramBase = i * 5;
    const float yieldThreshold = shapeMaterialParameters[paramBase + 1];
    const float creep = shapeMaterialParameters[paramBase + 2];
    const float recovery = shapeMaterialParameters[paramBase + 3];
    const float maxDeformation = shapeMaterialParameters[paramBase + 4];

    // if we are allowed to absorb deformation:
    if (creep > 0)
    {
        // obtain rotation matrix:
        float4x4 rotation = q_toMatrix(constraintOrientations[i]);
        rotation[3][3] = 1;

        // get scale matrix (A = RS so S = Rt * A) and its deviation from the identity matrix:
        const float4x4 strain = mul(transpose(rotation), linearTransforms[i]) - FLOAT4X4_IDENTITY;

        // if the amount of deformation exceeds the yield threshold:
        if (FrobeniusNorm(strain) > yieldThreshold)
        {
            // deform the shape permanently:
            float4x4 deformed = mul(FLOAT4X4_IDENTITY + creep * strain, RW_deformation[i]);

            // clamp deformation so that it does not exceed a percentage:
            const float4x4 offset = deformed - FLOAT4X4_IDENTITY;
            const float offsetNorm = FrobeniusNorm(offset);
            if (offsetNorm > maxDeformation)
            {
                deformed = FLOAT4X4_IDENTITY + maxDeformation * offset / offsetNorm;
            }

            // must be stored before RecalculateRestData, which reads RW_deformation[i]:
            RW_deformation[i] = deformed;

            // if we cannot recover from plastic deformation, recalculate rest shape now:
            if (recovery == 0)
                RecalculateRestData(i);
        }
    }

    // if we can recover from plastic deformation, lerp towards non-deformed shape and recalculate rest shape:
    if (recovery > 0)
    {
        const float4x4 current = RW_deformation[i];
        const float blend = min(recovery * deltaTime, 1.0f);

        RW_deformation[i] = current + (FLOAT4X4_IDENTITY - current) * blend;
        RecalculateRestData(i);
    }
}

// Kernel: one thread per constraint. Calls ApplyPositionDelta on RW_positions for
// every particle referenced by constraint i, forwarding sorFactor.
[numthreads(128, 1, 1)]
void Apply (uint3 id : SV_DispatchThreadID)
{
    const uint i = id.x;
    if (i >= activeConstraintCount) return;

    const int begin = firstIndex[i];
    const int count = numIndices[i];

    for (int n = 0; n < count; ++n)
    {
        const int particle = particleIndices[begin + n];
        ApplyPositionDelta(RW_positions, particle, sorFactor);
    }
}